﻿/**

@author : fanwenjie
@date    : 2020-02-01
@note	: 根据Yann LeCun的论文《Gradient-based Learning Applied To Document Recognition》编写
@api	:

初始化设备上下文
ctx: 需要初始化的设备上下文
VkResult CreateDeviceContext(DeviceContext* ctx);

销毁设备上下文
ctx: 需要销毁的设备上下文
void DestroyDeviceContext(DeviceContext* ctx);

初始化训练缓存
ctx: 设备上下文
cache: 训练缓存
batchSize: 批量训练数
VkResult CreateTrainCache(DeviceContext* ctx, TrainCache* cache, const uint32_t batchSize);

销毁训练缓存
ctx: 设备上下文
cache: 训练缓存
void DestroyTrainCache(DeviceContext* ctx, TrainCache* cache);

从主存中加载模型到设备内存中
ctx: 设备上下文
data: 主存中的模型数据
void LoadModel(DeviceContext* ctx, LeNet5* data);

将设备内存中的模型加载到主存中
ctx: 设备上下文
data: 接收模型数据的主存对象
void SaveModel(DeviceContext* ctx, LeNet5* data);

预测模型结果
ctx: 设备上下文
feature: 特征数据
uint32_t Predict(DeviceContext* ctx, Feature* feature);

批量训练模型
ctx: 设备上下文
cache: 训练缓存
feature: 参与训练的特征数据
label: 训练的标签
void TrainBatch(DeviceContext* ctx, TrainCache* cache, Feature* feature, uint32_t* label);
**/

#ifndef LENET_H
#define LENET_H
#include <vulkan/vulkan.h>
#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <memory.h>
#include <math.h>
#include "public.h"

/* Forward declarations only: this header uses both types exclusively through
   pointers, so their complete definitions are not required here (presumably
   they come from another project header such as public.h -- confirm). */
struct Feature;
struct LeNet5;
typedef struct Feature Feature;
typedef struct LeNet5 LeNet5;

/* A Vulkan buffer handle kept together with a device memory allocation
   (presumably the memory backing Buffer -- confirm in the implementation). */
typedef struct StorageBuffer StorageBuffer;
struct StorageBuffer
{
    VkBuffer Buffer;        /* buffer object handle */
    VkDeviceMemory Memory;  /* associated device memory handle */
};


/* Sizes of the pipeline tables stored in DeviceContext. Named so that code
   iterating over a table can use the constant instead of repeating a literal;
   the values are identical to the original hard-coded array sizes, so the
   struct layout is unchanged. Which network stage each index refers to is
   decided by the implementation. */
enum
{
    LENET_FORWARD_PIPELINE_COUNT = 6,
    LENET_BACKWARD_PIPELINE_COUNT = 6,
    LENET_DELTA_PIPELINE_COUNT = 4
};

/**
 * Vulkan state used by every LeNet-5 entry point: instance, logical device and
 * queue, command recording objects, the pipelines together with their shared
 * layout and descriptor objects, and two device-side storage buffers.
 * Initialized by CreateDeviceContext and released by DestroyDeviceContext.
 */
typedef struct DeviceContext
{
    VkInstance instance;
    VkDevice device;
    VkQueue queue;
    VkCommandPool commandPool;
    VkCommandBuffer commandBuffer;
    VkPipeline forwardPipeline[LENET_FORWARD_PIPELINE_COUNT];
    VkPipeline softmaxPipeline;
    VkPipeline backwardPipeline[LENET_BACKWARD_PIPELINE_COUNT];
    VkPipeline deltaPipeline[LENET_DELTA_PIPELINE_COUNT];
    VkPipeline updatePipeline;
    VkPipelineLayout pipelineLayout;
    VkDescriptorSetLayout descriptorSetLayout;
    VkDescriptorPool descriptorPool;
    VkDescriptorSet descriptorSet;
    /* Device-side buffers; presumably `lenet` holds the model that LoadModel /
       SaveModel transfer and `feature` holds the input for Predict -- confirm. */
    StorageBuffer lenet, feature;
    VkMemoryBarrier barrier;
    /* Memory properties of the physical device; presumably consulted when
       choosing memory types for allocations -- confirm. */
    VkPhysicalDeviceMemoryProperties physicalDeviceMemoryProperty;
}DeviceContext;

typedef struct TrainCache
{
    uint32_t batchSize;
    VkPipelineLayout pipelineLayout;
    StorageBuffer feature, error, delta, label;
}TrainCache;

#ifdef __cplusplus
extern "C" {
#endif

    /**
     * Initialize a device context.
     * @param ctx  Device context to initialize.
     * @return     Vulkan result code of the initialization (presumably
     *             VK_SUCCESS on success -- confirm in the implementation).
     */
    VkResult CreateDeviceContext(DeviceContext* ctx);

    /**
     * Destroy a device context.
     * @param ctx  Device context to destroy.
     */
    void DestroyDeviceContext(DeviceContext* ctx);

    /**
     * Initialize a training cache.
     * @param ctx        Device context.
     * @param cache      Training cache to initialize.
     * @param batchSize  Number of samples trained per batch.
     * @return           Vulkan result code (presumably VK_SUCCESS on success
     *                   -- confirm in the implementation).
     */
    VkResult CreateTrainCache(DeviceContext* ctx, TrainCache* cache, uint32_t batchSize);

    /**
     * Destroy a training cache.
     * @param ctx    Device context.
     * @param cache  Training cache to destroy.
     */
    void DestroyTrainCache(DeviceContext* ctx, TrainCache* cache);

    /**
     * Load a model from host memory into device memory.
     * @param ctx   Device context receiving the model.
     * @param data  Model in host memory.
     */
    void LoadModel(DeviceContext* ctx, LeNet5* data);

    /**
     * Copy the model held in device memory back into host memory.
     * @param ctx   Device context holding the model.
     * @param data  Host-side model that receives the data.
     */
    void SaveModel(DeviceContext* ctx, LeNet5* data);

    /**
     * Run the model on one feature sample.
     * @param ctx      Device context.
     * @param feature  Feature data to classify.
     * @return         Prediction result (presumably the index of the predicted
     *                 class -- confirm in the implementation).
     */
    uint32_t Predict(DeviceContext* ctx, Feature* feature);

    /**
     * Train the model on one batch.
     * @param ctx      Device context.
     * @param cache    Training cache.
     * @param feature  Training features (assumed to hold cache->batchSize
     *                 samples -- TODO confirm).
     * @param label    Training labels (assumed one per sample -- TODO confirm).
     */
    void TrainBatch(DeviceContext* ctx, TrainCache* cache, Feature* feature, uint32_t* label);

#ifdef __cplusplus
}
#endif

#endif
